{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "29888eda-2755-4add-af9c-8927afb07db4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "from tokenizers import Tokenizer\n",
    "#tokenizer = AutoTokenizer.from_pretrained('human_gpt2-v1')\n",
    "#然后我们可以使用from_file() 方法从该文件里重新加载 Tokenizer 对象：\n",
    "# #然后我们可以使用from_file() 方法从该文件里重新加载 Tokenizer 对象：\n",
    "new_tokenizer = Tokenizer.from_file(\"tokenizer8.json\")\n",
    "# #或者下面方法\n",
    "# from transformers import GPT2TokenizerFast\n",
    "# tokenizer = GPT2TokenizerFast(tokenizer_object=new_tokenizer)\n",
    "\n",
    "from transformers import PreTrainedTokenizerFast\n",
    "\n",
    "tokenizer = PreTrainedTokenizerFast(\n",
    "    tokenizer_object=new_tokenizer,\n",
    "    bos_token=\"<s>\",\n",
    "    eos_token=\"</s>\",\n",
    "    unk_token=\"<unk>\",\n",
    "    pad_token=\"<pad>\",\n",
    "    cls_token=\"<cls>\",\n",
    "    sep_token=\"<sep>\",\n",
    "    mask_token=\"<mask>\",\n",
    "    padding_side=\"left\",\n",
    ")\n",
    "#tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "\n",
    "# if tokenizer.pad_token is None:\n",
    "#     tokenizer.add_special_tokens({'pad_token': '[PAD]'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "75e3778d-9cbb-48c4-8203-0f71f485a49d",
   "metadata": {},
   "outputs": [],
   "source": [
    "full_model = AutoModel.from_pretrained('mygpt_unigram8')\n",
    "#model.config.eos_token_id\n",
    "# print(model.config.pad_token_id)\n",
    "#model.config.pad_token_id = model.config.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "be845310-8215-4434-bb92-d44c343f274e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "transformers.models.gpt2.modeling_gpt2\n"
     ]
    }
   ],
   "source": [
    "gena_module_name = full_model.__class__.__module__\n",
    "print(gena_module_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bb9ee647-cac2-4475-ac44-2f11aef99735",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "# available class names:\n",
    "# - BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction,\n",
    "# - BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification,\n",
    "# - BertForQuestionAnswering\n",
    "# check https://huggingface.co/docs/transformers/model_doc/bert\n",
    "myclass = importlib.import_module(gena_module_name)\n",
    "#dir(myclass)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "62ebfcc0-5d46-4c3b-b462-a3842a985e9b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cls = getattr(importlib.import_module(gena_module_name), 'GPT2ForSequenceClassification')\n",
    "cls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4523e9b6-3613-41e4-b81d-c139d044e70c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at mygpt_unigram8 and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = cls.from_pretrained('mygpt_unigram8', num_labels=2)\n",
    "#dir(model)\n",
    "#model.config.eos_token_id\n",
    "#print(model.config.pad_token_id)\n",
    "model.config.pad_token_id = model.config.eos_token_id\n",
    "#model.resize_token_embeddings(len(tokenizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bd024199-16e6-4eb1-b404-c49dc535469b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "# load ~11k samples from promoters prediction dataset\n",
    "dataset = load_dataset(\"yurakuratov/example_promoters_2k\")['train'].train_test_split(test_size=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a3a5213d-66f0-4d17-876c-de4426145412",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sequence', 'promoter_presence'],\n",
       "        num_rows: 10656\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sequence', 'promoter_presence'],\n",
       "        num_rows: 1184\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "36a23870-1f56-466e-b2bb-c5047072b8f0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'sequence': 'GTGGTGGCTCACGCCTGTAATCCCAGCACAGTGGGAGGCTAAGGCGGGCAGATCACTCGAGGTGAGGAATTTGAGACCAGCCTGGCCAACATGGTGAAACAGTCACAACGACCTCAGGCTTGTTTTCTTCGCTTGGATCCAGCTGCCCTGCCCCCAGTCACTCCCAGAATTTCAGGGCGCGCCCCGTGGCCCCGCCTCTCGGAGGGGGAGGTTTCAGTGAGCCGAGACTGCGCCATTGCACTCCAGCCTGGGCAACAAGAGCGAAACTCCGTCTCGGAAAAAAAAAAAAAAGAAAAAAACTAAAATATCAGTACTCTTTTCACTTAGTCATCACAAACCCCTTTAATCCCATATTACCCTCTACTATCATGGTAGCCTCACCTTTGACGACAGACTAGCTGTGTGACCCAAGTAAATGACTTCAGTGCCACCACCTGTAGGAAGGGGCTGCTAATAGCATCTGCTTCATACTAACGTGTTTTCAGGATCGATCTGCATATAAACGCTCAATACACGGTAACTGATTGTCACATTCTATCCCAGGACGCTTCCATTTGCGGATCTCATGGCCCGGAGGACCTGGGCCGCCTCACAACTGGGTCTGAGTCTTGCGTGGGTCCTCTATATAGGGTGAGAAGCGTGGCGCTCGGTTCCTGCCTCGGGGAAGTCCTGGCGCAGATGGGCCACGGGGCCGGCGTGGAGCTGGAGGCAGATCCCGCCCTTGCAGGGCTGGGTAAGAGACTGAAACTTGGAGAATGAACTGCAGGCGCTGCACATCTCGGGCAGATTTAAAATTCGGTCAGCCCCGCTCACCGACACAAAGCCCACAACCGCGCCCAATAGGGAATCGGCCCTGCGGCCCTCCTGCCACTCGCCCCCAACTATCACAGTCCACGGCGGCCCGTCTCTACTAAAAATACAAAAATTAGCCGGGCGTGGTGGTGGGCGCTTGTAATCCCGGCTACTCAGGAGGCTGAGGCAGGAGAATCGCTTGAACCCGGTCTCCAATTGGTCCCTCGCTCATCCAGTCCCCGTAGGCCAATCTTCCCTTCTCATTGGTAGGCGGTGAGTGTGTGGAGAAACCTGACCTAATGGGCGCTGGAGTTCGCAAAACGTGACCCGGAAATTGGTCACGAGCAGGCGCCGTGGGCTTGTGGACGCCTAACTTGCGCGCTGAGATTTCCGGCGTGGGAGCAGAGGTCCGGGTCTCACCCCGACACTGACGCACTGGAGAGCGCGTGGCCCTGGGATCTGCGTTTTCCCAAGCGCTCCGGAGTAATTCTAACCCAGGTTGCCCGCCCTAAAATCGGAAGCAGGTGCCGCGAACCCAGACCCGAGCGCTCCTCTGGGACGTCCAGCCGCAGAAGCGAGCCCAAAGCCCAGCAGGGTTCTCCGCGGGAAGCAGGCCACGCGGCGCCACTGACTGGGCGGGCGGCTAGCTCCCTGGCCTCCCCGACATGCCGAGGAGCCCCTGGCTTCCGGAGCCGACGAGGCCGCAGGCTCGGGGTTTCAGAACCGCCGGGCTCTTGGGCGCTCGCCCAGGGCGCACGCGCAGTCGCGAGGCCGCGCCGCGGACTGCATCTCCCAGCAGGCCCCGCGCGGCCAATGGGACCATCTTCAGCCAGCCCAGAGTAGCCGAAGCGGGTGGGGCCTGGCTTGGGAAAGGATTCATCCGGGGAAGCTGATTAACAATTCAGATTCGGCGCCTGGGACCGACTGAGGCCTAGGCGCCGGAGCCGGCCGCGCCTGGGCTGGAGCGGGGCTCCTCGGCCTGGACTGGGAGCCCCCGGCCCCGGGCTCCTGCTGGCGCCGTCCAACCTTACATGGGTTCAGGGCGCCTTCGTAGGCGGGCACGGCTGGTTTCGGGCTAAGGCGCTCTGGAGACCTGACGATGGCGTCGGGCCCGGGCTCCCAGGAACGGGAAGGGCTCCTGATAGTGAAGCTGGAGGAGGACTGCGCCTGGAGCCAGGAGCTGCCCCCACCTGACCCAGGACCGAGCC',\n",
       " 'promoter_presence': 0}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset['train'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "19e22500-c1e0-444e-9b5c-2a1a2af81997",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# base pairs:  2000\n"
     ]
    }
   ],
   "source": [
    "print('# base pairs: ', len(dataset['train'][0]['sequence']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "16db0a5d-7aec-4d59-b44b-d22ff43045fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tokens:  GTGG TGGCTCACGCCT GTAATCCCAGCA CAGTG GGAGGCTAAGGC GGGCAG ATCACTC GAGGTGAGG AATTTGAG ACCAGCCTGGCC AACATGGTGAAA CAGTCACA ACGA CCTCAGGCT TGTTTTCTT CG CTTGGA TCCAGCT GCCCTGCCCCCA GTCACT CCCAGAA TTTCAGG GCGCGCCC CGTGG CCCCGCCTC TCG GAGGGGGAGG TTTCAGT GAGCCGAGA CTG CGCCATTGCACT CCAGCCTGGGCA ACAAGAGCGAAA CTCCGTCTC GG AAAAAAAAAAAA AAGAAAAAAA CTAAAAT ATCAGTA CTC TTTTCACTT AGTCATCA CAAACC CCTTTAA TCCCATA TTACCCT CTACTAT CATGGTAG CCTCACC TTTGAC GACA GACTAGCTG TGTGACCC AAGTAAATGA CTTCAGTG CCACCACC TGTAGGA AGGGGCTGC TAATAG CATCTGCTT CATACTA ACGTGT TTTCAGGAT CGA TCTGCATAT AAAC GCTCA ATACAC GGTAACTG ATTGTCA CATTCTAT CCCAGGAC GCTT CCATTTGC GGATC TCATGGCCC GGAGGA CCTGGGCC GCCTCAC AACTGGG TCTGAGTCT TGCGTG GGTCCTC TATATAG GGTGAGAA GCGTGG CGCTC GGTTCCTG CCTCGGG GAAGT CCTGGC GCAGATG GGCCAC GGGGCC GGCGTGG AGCTG GAGGCAG ATCCC GCCCTT GCAGGGCTGGG TAAGAG ACTGAAA CTTGGAGAA TGAA CTGCAGGC GCTG CACATCTC GGGCAG ATTTAAAA TTCGGT CAGCCCCG CTCACC GA CACAAAG CCCACAA CCGCGCC CAATAG GGAATC GGCCCTGC GGCC CTCCTG CCACTCG CCCCCAAC TATCACAG TCCAC GGCG GCCCGTCTC TACTAAAAATAC AAAAATTAGCCG GGCGTGGTGGTG GGCG CTTGTAAT CCCGGCTA CTCA GGAGGCTGAGGC AGGAGAATCGCT TGAACCCGG TCTCCAA TTGG TCCCTC GCTCA TCCAGT CCCCG TAGGCCAAT CTTCCCTTCTC ATTGGT AGGCG GTGAGTG TGTGGAGAAA CCTGACCT AAT GGGCGCTG GAGTTC GCAAAAC GTGACCC GGAAAT TGGTCAC GAGCA GGCGCCG TGGGCT TGTGGAC GCCT AACTTG CGCGC TGAGA TTTCC GGCGTG GGAGCAGAGG TCCGGG TCTCA CCCCGAC ACTGAC GCACT GGAGAGC GCGTGGCC CTGGGAT CTGCGTT TTCCCAAG CGCTCC GGAG TAATTCT AACCCAGG TTG CCCGCC CTAAAATC GGAAG CAGGTGCC GC GAA CCCAGACCC GAGCGC TCCTCTGGG ACGT CCAGCC GCAGAAG CGAG CCCAAAG CCCAGCA GGGTT CTCCGC GGGAAGCAG GCCA CGCGGCG CCACTGA CTG GGCGGGCGG CTAGCT CCCTGGCCT CCCCGAC ATGCC GAG GAGCCCCTGG CTTCC GGAGCC GACGAGG CCGCAGG CTCGGGG TTT CAGAACC GCCGGGC TCTTGG GCGCTCGC CCAGGGC GCACGCGC AGTC GCGAGGCC G CGCCGCGG ACTGCAT CTCCCAG CAGGCCCC GCGCGGCC AATGGGA CCATCTT CAGCCAGC CCAGAGT AGCC GAAGCGGG TGGGGCCTGG CTTGGGAA AGGATTCA TCCG GGGAAGCTG ATTAACAA TTCAGATT CGGCGCC TGGGACC GACTGAG GCCTAGGC GCC GGAGCC GGCCGCG CCTGGGCT GGAGCGGG GCTCCTC GGCCTGGA CTGGGAGCC CCCGGCC CCGGGC TCCTGCT GGCGCCG TCCA ACCTTACA TGGGTT CAGGGC GCCTTC GTAGGC GGGCACGG CTGGTTT CGGG CTAAGGC GCTCTGG AGACCTG ACGAT GGCGTC GGGCCC GGGCTCCC AGGAAC GGGAAGGG CTCCTG ATAGT GAAGCTGGA GGAGGA CTGCGCCTG GAGCCAGG AG CTGCCCCCA CCTGAC CCAGGACC GAGCC\n"
     ]
    }
   ],
   "source": [
    "print('tokens: ', ' '.join(tokenizer.tokenize(dataset['train'][0]['sequence'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9876c07a-6332-40cc-aa84-2d26e6a3a5ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# tokens:  265\n"
     ]
    }
   ],
   "source": [
    "print('# tokens: ', len(tokenizer.tokenize(dataset['train'][100]['sequence'])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7c4b0442-5fdf-4845-8bf2-4c3b6cc86e84",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 10656/10656 [00:00<00:00, 11280.05 examples/s]\n",
      "Map: 100%|██████████| 1184/1184 [00:00<00:00, 10401.66 examples/s]\n"
     ]
    }
   ],
   "source": [
    "def preprocess_labels(example):\n",
    "  example['label'] = example['promoter_presence']\n",
    "  return example\n",
    "\n",
    "dataset = dataset.map(preprocess_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "46e912af-1aab-4ed2-8007-0f3f970b2255",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(examples):\n",
    "  # just truncate right, but for some tasks symmetric truncation from left and right is more reasonable\n",
    "  # set max_length to 128 to make experiments faster\n",
    "  return tokenizer(examples[\"sequence\"], truncation=True, max_length=256) #max_length 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0ef16794-71fd-4404-a51e-b55b5cae9aab",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 10656/10656 [00:02<00:00, 4146.09 examples/s]\n",
      "Map: 100%|██████████| 1184/1184 [00:00<00:00, 4667.25 examples/s]\n"
     ]
    }
   ],
   "source": [
    "tokenized_dataset = dataset.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "765c199b-9d13-4fd2-ac51-1e5d154766fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "afd6f48f-5524-4740-b8ea-54a79fd79dad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['sequence', 'promoter_presence', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "        num_rows: 10656\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['sequence', 'promoter_presence', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "        num_rows: 1184\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e1b708bb-21e4-416d-b191-79f302ca4e93",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "import numpy as np\n",
    "\n",
    "#tokenizer.pad_token = tokenizer.eos_token \n",
    "#tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    predictions = np.argmax(predictions, axis=1)\n",
    "    return {'accuracy': (predictions==labels).sum() / len(labels)}\n",
    "\n",
    "# change training hyperparameters to archive better quality\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"test_run\",\n",
    "    learning_rate=1e-4,\n",
    "    lr_scheduler_type=\"constant_with_warmup\",\n",
    "    warmup_ratio=0.1,\n",
    "    optim='adamw_torch',\n",
    "    weight_decay=0.0,\n",
    "    per_device_train_batch_size=20,\n",
    "    per_device_eval_batch_size=20,\n",
    "    num_train_epochs=5,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_dataset[\"train\"],\n",
    "    eval_dataset=tokenized_dataset[\"test\"],\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "082754e6-53fc-4a95-9d1e-887c42975d71",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1203' max='1335' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1203/1335 07:58 < 00:52, 2.51 it/s, Epoch 4.50/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.669700</td>\n",
       "      <td>0.526868</td>\n",
       "      <td>0.733108</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.500200</td>\n",
       "      <td>0.570136</td>\n",
       "      <td>0.710304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.247300</td>\n",
       "      <td>0.852644</td>\n",
       "      <td>0.699324</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.113100</td>\n",
       "      <td>0.941890</td>\n",
       "      <td>0.706926</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/liming/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55c9d1eb-6fc6-451f-b596-870ccaa81d8d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
